{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 08: PyTorch Profiling\n",
    "\n",
    "This notebook is an experiment to try out the PyTorch profiler.\n",
    "\n",
    "See here for more:\n",
    "* https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/\n",
    "* https://pytorch.org/docs/stable/profiler.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torchvision import transforms, datasets\n",
    "from torchinfo import summary\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from going_modular import data_setup, engine"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get and load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/10_whole_foods dir exists, skipping download\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import requests\n",
    "from zipfile import ZipFile\n",
    "\n",
    "def get_food_image_data():\n",
    "    if not os.path.exists(\"data/10_whole_foods\"):\n",
    "        os.makedirs(\"data/\", exist_ok=True)\n",
    "        # Download data\n",
    "        data_url = \"https://storage.googleapis.com/food-vision-image-playground/10_whole_foods.zip\"\n",
    "        print(f\"Downloading data from {data_url}...\")\n",
    "        requests.get(data_url)\n",
    "        # Unzip data\n",
    "        targ_dir = \"data/10_whole_foods\"\n",
    "        print(f\"Extracting data to {targ_dir}...\")\n",
    "        with ZipFile(\"10_whole_foods.zip\") as zip_ref:\n",
    "            zip_ref.extractall(targ_dir)\n",
    "    else:\n",
    "        print(\"data/10_whole_foods dir exists, skipping download\")\n",
    "\n",
    "get_food_image_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<torch.utils.data.dataloader.DataLoader at 0x7f052ab26e20>,\n",
       " <torch.utils.data.dataloader.DataLoader at 0x7f05299eed00>,\n",
       " ['apple',\n",
       "  'banana',\n",
       "  'beef',\n",
       "  'blueberries',\n",
       "  'carrots',\n",
       "  'chicken_wings',\n",
       "  'egg',\n",
       "  'honey',\n",
       "  'mushrooms',\n",
       "  'strawberries'])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Setup dirs\n",
    "train_dir = \"data/10_whole_foods/train\"\n",
    "test_dir = \"data/10_whole_foods/test\"\n",
    "\n",
    "# Setup ImageNet normalization levels (turns all images into similar distribution as ImageNet)\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "# Create starter transform\n",
    "simple_transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    normalize\n",
    "])           \n",
    "\n",
    "# Create data loaders\n",
    "train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(\n",
    "    train_dir=train_dir,\n",
    "    test_dir=test_dir,\n",
    "    transform=simple_transform,\n",
    "    batch_size=32,\n",
    "    num_workers=8\n",
    ")\n",
    "\n",
    "train_dataloader, test_dataloader, class_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torchvision.models.efficientnet_b0(pretrained=True).to(device)\n",
    "# model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Update the classifier\n",
    "model.classifier = torch.nn.Sequential(\n",
    "    nn.Dropout(p=0.2),\n",
    "    nn.Linear(1280, len(class_names)).to(device))\n",
    "\n",
    "# Freeze all base layers \n",
    "for param in model.features.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model and track results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define loss and optimizer\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adjust training function to track results with `SummaryWriter`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'EfficietNetB0'"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.name = \"EfficietNetB0\"\n",
    "model.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from going_modular.engine import train_step, test_step\n",
    "from tqdm import tqdm\n",
    "writer = SummaryWriter()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Update the `train_step()` function to include the PyTorch profiler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(model, dataloader, loss_fn, optimizer):\n",
    "    model.train()\n",
    "    train_loss, train_acc = 0, 0\n",
    "    ## NEW: Add PyTorch profiler\n",
    "\n",
    "    dir_to_save_logs = os.path.join(\"logs\", datetime.now().strftime(\"%Y-%m-%d-%H-%M\"))\n",
    "    with torch.profiler.profile(\n",
    "        on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name=dir_to_save_logs),\n",
    "        # with_stack=True # this adds a lot of overhead to training (tracing all the stack)\n",
    "    ):\n",
    "        for batch, (X, y) in enumerate(dataloader):\n",
    "            # Send data to GPU\n",
    "            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
    "            \n",
    "            # Turn on mixed precision if available\n",
    "            with torch.autocast(device_type=device, enabled=True):\n",
    "                # 1. Forward pass\n",
    "                y_pred = model(X)\n",
    "\n",
    "                # 2. Calculate loss\n",
    "                loss = loss_fn(y_pred, y)\n",
    "\n",
    "            # 3. Optimizer zero grad\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # 4. Loss backward\n",
    "            loss.backward()\n",
    "\n",
    "            # 5. Optimizer step\n",
    "            optimizer.step()\n",
    "\n",
    "            # 6. Calculate metrics\n",
    "            train_loss += loss.item()\n",
    "            y_pred_class = torch.softmax(y_pred, dim=1).argmax(dim=1)\n",
    "            # print(f\"y: \\n{y}\\ny_pred_class:{y_pred_class}\")\n",
    "            # print(f\"y argmax: {y_pred.argmax(dim=1)}\")\n",
    "            # print(f\"Equal: {(y_pred_class == y)}\")\n",
    "            train_acc += (y_pred_class == y).sum().item() / len(y_pred)\n",
    "            # print(f\"batch: {batch} train_acc: {train_acc}\")\n",
    "\n",
    "    # Adjust returned metrics\n",
    "    return train_loss / len(dataloader), train_acc / len(dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TK - Now to use the writer, we've got to adjust the `train()` function..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    test_dataloader,\n",
    "    optimizer,\n",
    "    loss_fn=nn.CrossEntropyLoss(),\n",
    "    epochs=5,\n",
    "):\n",
    "\n",
    "    results = {\"train_loss\": [], \"train_acc\": [], \"test_loss\": [], \"test_acc\": []}\n",
    "\n",
    "    for epoch in tqdm(range(epochs)):\n",
    "        train_loss, train_acc = train_step(\n",
    "            model=model,\n",
    "            dataloader=train_dataloader,\n",
    "            loss_fn=loss_fn,\n",
    "            optimizer=optimizer,\n",
    "        )\n",
    "        test_loss, test_acc = test_step(\n",
    "            model=model, dataloader=test_dataloader, loss_fn=loss_fn\n",
    "        )\n",
    "\n",
    "        # Print out what's happening\n",
    "        print(\n",
    "            f\"Epoch: {epoch+1} | \"\n",
    "            f\"train_loss: {train_loss:.4f} | \"\n",
    "            f\"train_acc: {train_acc:.4f} | \"\n",
    "            f\"test_loss: {test_loss:.4f} | \"\n",
    "            f\"test_acc: {test_acc:.4f}\"\n",
    "        )\n",
    "\n",
    "        # Update results\n",
    "        results[\"train_loss\"].append(train_loss)\n",
    "        results[\"train_acc\"].append(train_acc)\n",
    "        results[\"test_loss\"].append(test_loss)\n",
    "        results[\"test_acc\"].append(test_acc)\n",
    "\n",
    "        # Add results to SummaryWriter\n",
    "        writer.add_scalars(main_tag=\"Loss\", \n",
    "                           tag_scalar_dict={\"train_loss\": train_loss,\n",
    "                                            \"test_loss\": test_loss},\n",
    "                           global_step=epoch)\n",
    "        writer.add_scalars(main_tag=\"Accuracy\", \n",
    "                           tag_scalar_dict={\"train_acc\": train_acc,\n",
    "                                            \"test_acc\": test_acc}, \n",
    "                           global_step=epoch)\n",
    "    \n",
    "    # Close the writer\n",
    "    writer.close()\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 1/5 [00:05<00:21,  5.27s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | train_loss: 1.9644 | train_acc: 0.4386 | test_loss: 1.5205 | test_acc: 0.7865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 2/5 [00:09<00:14,  4.94s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2 | train_loss: 1.2589 | train_acc: 0.7878 | test_loss: 1.1589 | test_acc: 0.7604\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 3/5 [00:14<00:09,  4.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 3 | train_loss: 0.8642 | train_acc: 0.8776 | test_loss: 0.9347 | test_acc: 0.7917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 4/5 [00:18<00:04,  4.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 4 | train_loss: 0.6827 | train_acc: 0.8856 | test_loss: 0.6637 | test_acc: 0.8750\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:23<00:00,  4.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 5 | train_loss: 0.5688 | train_acc: 0.9069 | test_loss: 0.6175 | test_acc: 0.8854\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Train model\n",
    "# Note: Not using engine.train() since the original script isn't updated\n",
    "results = train(model=model,\n",
    "        train_dataloader=train_dataloader,\n",
    "        test_dataloader=test_dataloader,\n",
    "        optimizer=optimizer,\n",
    "        loss_fn=loss_fn,\n",
    "        epochs=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looks like mixed precision doesn't offer much benefit for smaller feature extraction models..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Without mixed precision\n",
    "#  20%|██        | 1/5 [00:03<00:14,  3.71s/it]Epoch: 1 | train_loss: 0.5229 | train_acc: 0.9054 | test_loss: 0.5776 | test_acc: 0.8542\n",
    "#  40%|████      | 2/5 [00:07<00:11,  3.74s/it]Epoch: 2 | train_loss: 0.4699 | train_acc: 0.9001 | test_loss: 0.5160 | test_acc: 0.8802\n",
    "#  60%|██████    | 3/5 [00:10<00:07,  3.63s/it]Epoch: 3 | train_loss: 0.3913 | train_acc: 0.9196 | test_loss: 0.4888 | test_acc: 0.8906\n",
    "#  80%|████████  | 4/5 [00:14<00:03,  3.61s/it]Epoch: 4 | train_loss: 0.3724 | train_acc: 0.9371 | test_loss: 0.4931 | test_acc: 0.8698\n",
    "# 100%|██████████| 5/5 [00:18<00:00,  3.61s/it]Epoch: 5 | train_loss: 0.3315 | train_acc: 0.9381 | test_loss: 0.4405 | test_acc: 0.8750\n",
    "\n",
    "# # With mixed precision\n",
    "#  20%|██        | 1/5 [00:04<00:17,  4.40s/it]Epoch: 1 | train_loss: 0.3027 | train_acc: 0.9554 | test_loss: 0.4386 | test_acc: 0.8802\n",
    "#  40%|████      | 2/5 [00:08<00:13,  4.49s/it]Epoch: 2 | train_loss: 0.2826 | train_acc: 0.9539 | test_loss: 0.4080 | test_acc: 0.8802\n",
    "#  60%|██████    | 3/5 [00:13<00:08,  4.48s/it]Epoch: 3 | train_loss: 0.2450 | train_acc: 0.9609 | test_loss: 0.4130 | test_acc: 0.8750\n",
    "#  80%|████████  | 4/5 [00:18<00:04,  4.53s/it]Epoch: 4 | train_loss: 0.2450 | train_acc: 0.9594 | test_loss: 0.4158 | test_acc: 0.8802\n",
    "# 100%|██████████| 5/5 [00:22<00:00,  4.49s/it]Epoch: 5 | train_loss: 0.2307 | train_acc: 0.9639 | test_loss: 0.4124 | test_acc: 0.8906"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try mixed precision with larger model\n",
    "\n",
    "Now we'll try turn on mixed precision with a larger model (e.g. EffifientNetB0 with all layers tuneable)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unfreeze all base layers \n",
    "for param in model.features.parameters():\n",
    "    param.requires_grad = True\n",
    "\n",
    "# for param in model.features.parameters():\n",
    "#     print(param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 1/5 [00:13<00:53, 13.27s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | train_loss: 0.4934 | train_acc: 0.8586 | test_loss: 0.6467 | test_acc: 0.7969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 2/5 [00:27<00:42, 14.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2 | train_loss: 0.1750 | train_acc: 0.9628 | test_loss: 1.1806 | test_acc: 0.8385\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 3/5 [00:42<00:28, 14.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 3 | train_loss: 0.1362 | train_acc: 0.9619 | test_loss: 0.5831 | test_acc: 0.8802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 4/5 [00:57<00:14, 14.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 4 | train_loss: 0.1743 | train_acc: 0.9462 | test_loss: 0.5702 | test_acc: 0.8854\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [01:11<00:00, 14.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 5 | train_loss: 0.2437 | train_acc: 0.9352 | test_loss: 0.7096 | test_acc: 0.8125\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Train model\n",
    "# Note: Not using engine.train() since the original script isn't updated\n",
    "results = train(model=model,\n",
    "        train_dataloader=train_dataloader,\n",
    "        test_dataloader=test_dataloader,\n",
    "        optimizer=optimizer,\n",
    "        loss_fn=loss_fn,\n",
    "        epochs=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Without mixed precision...\n",
    "#  20%|██        | 1/5 [00:11<00:46, 11.61s/it]Epoch: 1 | train_loss: 0.4507 | train_acc: 0.8648 | test_loss: 1.0603 | test_acc: 0.7604\n",
    "#  40%|████      | 2/5 [00:24<00:36, 12.21s/it]Epoch: 2 | train_loss: 0.1659 | train_acc: 0.9464 | test_loss: 0.6398 | test_acc: 0.8490\n",
    "#  60%|██████    | 3/5 [00:36<00:24, 12.38s/it]Epoch: 3 | train_loss: 0.1261 | train_acc: 0.9698 | test_loss: 0.7149 | test_acc: 0.8542\n",
    "#  80%|████████  | 4/5 [00:49<00:12, 12.53s/it]Epoch: 4 | train_loss: 0.1250 | train_acc: 0.9609 | test_loss: 0.7441 | test_acc: 0.7917\n",
    "# 100%|██████████| 5/5 [01:02<00:00, 12.42s/it]Epoch: 5 | train_loss: 0.1282 | train_acc: 0.9564 | test_loss: 0.8701 | test_acc: 0.8385\n",
    "\n",
    "# # With mixed precision...\n",
    "#  20%|██        | 1/5 [00:13<00:53, 13.27s/it]Epoch: 1 | train_loss: 0.4934 | train_acc: 0.8586 | test_loss: 0.6467 | test_acc: 0.7969\n",
    "#  40%|████      | 2/5 [00:27<00:42, 14.09s/it]Epoch: 2 | train_loss: 0.1750 | train_acc: 0.9628 | test_loss: 1.1806 | test_acc: 0.8385\n",
    "#  60%|██████    | 3/5 [00:42<00:28, 14.28s/it]Epoch: 3 | train_loss: 0.1362 | train_acc: 0.9619 | test_loss: 0.5831 | test_acc: 0.8802\n",
    "#  80%|████████  | 4/5 [00:57<00:14, 14.47s/it]Epoch: 4 | train_loss: 0.1743 | train_acc: 0.9462 | test_loss: 0.5702 | test_acc: 0.8854\n",
    "# 100%|██████████| 5/5 [01:11<00:00, 14.38s/it]Epoch: 5 | train_loss: 0.2437 | train_acc: 0.9352 | test_loss: 0.7096 | test_acc: 0.8125"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Checking the PyTorch profiler, it seems that mixed precision utilises some Tensor Cores, however, these aren't large numbers.\n",
    "\n",
    "E.g. it uses 9-12% Tensor Cores. Perhaps the slow down when using mixed precision is because the tensors have to get altered and converted when there isn't very many of them. For example only 9-12% of tensors get converted so the speed up gains aren't realised on these tensors because they get cancelled out by the conversion time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extensions\n",
    "* Does changing the data input size to EfficientNetB4 change its results? E.g. input image size of (380, 380) instead of (224, 224)?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " "
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3fbe1355223f7b2ffc113ba3ade6a2b520cadace5d5ec3e828c83ce02eb221bf"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
